{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HZiF5lbumA7j"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eNj0_BTFk479"
      },
      "source": [
        "# TF Lattice 预制模型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T3qE8F5toE28"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/lattice/tutorials/premade_models\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/lattice/tutorials/premade_models.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">Run in Google Colab</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/lattice/tutorials/premade_models.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 中查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/lattice/tutorials/premade_models.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\"> 下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HEuRMAUOlFZa"
      },
      "source": [
        "## 概述\n",
        "\n",
        "利用预制模型，您可以快速轻松地针对典型用例构建 TFL `tf.keras.model` 实例。本指南概述了构造 TFL 预制模型并对其进行训练/测试所需的步骤。 "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f2--Yq21lhRe"
      },
      "source": [
        "## 设置\n",
        "\n",
        "安装 TF Lattice 软件包："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XizqBCyXky4y"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install tensorflow-lattice pydot"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2oKJPy5tloOB"
      },
      "source": [
        "导入所需的软件包："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9wZWJJggk4al"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kpJJSS7YmLbG"
      },
      "source": [
        "下载 UCI Statlog (Heart) 数据集："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AYTcybljmQJm"
      },
      "outputs": [],
      "source": [
        "csv_file = tf.keras.utils.get_file(\n",
        "    'heart.csv', 'http://storage.googleapis.com/applied-dl/heart.csv')\n",
        "df = pd.read_csv(csv_file)\n",
        "train_size = int(len(df) * 0.8)\n",
        "train_dataframe = df[:train_size]\n",
        "test_dataframe = df[train_size:]\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ODe0oavWmtAi"
      },
      "source": [
        "提取特征和标签并将它们转换为张量："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3ae-Mx-PnGGL"
      },
      "outputs": [],
      "source": [
        "# Features:\n",
        "# - age\n",
        "# - sex\n",
        "# - cp        chest pain type (4 values)\n",
        "# - trestbps  resting blood pressure\n",
        "# - chol      serum cholestoral in mg/dl\n",
        "# - fbs       fasting blood sugar &gt; 120 mg/dl\n",
        "# - restecg   resting electrocardiographic results (values 0,1,2)\n",
        "# - thalach   maximum heart rate achieved\n",
        "# - exang     exercise induced angina\n",
        "# - oldpeak   ST depression induced by exercise relative to rest\n",
        "# - slope     the slope of the peak exercise ST segment\n",
        "# - ca        number of major vessels (0-3) colored by flourosopy\n",
        "# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect\n",
        "#\n",
        "# This ordering of feature names will be the exact same order that we construct\n",
        "# our model to expect.\n",
        "feature_names = [\n",
        "    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',\n",
        "    'exang', 'oldpeak', 'slope', 'ca', 'thal'\n",
        "]\n",
        "feature_name_indices = {name: index for index, name in enumerate(feature_names)}\n",
        "# This is the vocab list and mapping we will use for the 'thal' categorical\n",
        "# feature.\n",
        "thal_vocab_list = ['normal', 'fixed', 'reversible']\n",
        "thal_map = {category: i for i, category in enumerate(thal_vocab_list)}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x5p3OgC-m4TW"
      },
      "outputs": [],
      "source": [
        "# Custom function for converting thal categories to buckets\n",
        "def convert_thal_features(thal_features):\n",
        "  # Note that two examples in the test set are already converted.\n",
        "  return np.array([\n",
        "      thal_map[feature] if feature in thal_vocab_list else feature\n",
        "      for feature in thal_features\n",
        "  ])\n",
        "\n",
        "\n",
        "# Custom function for extracting each feature.\n",
        "def extract_features(dataframe,\n",
        "                     label_name='target',\n",
        "                     feature_names=feature_names):\n",
        "  features = []\n",
        "  for feature_name in feature_names:\n",
        "    if feature_name == 'thal':\n",
        "      features.append(\n",
        "          convert_thal_features(dataframe[feature_name].values).astype(float))\n",
        "    else:\n",
        "      features.append(dataframe[feature_name].values.astype(float))\n",
        "  labels = dataframe[label_name].values.astype(float)\n",
        "  return features, labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7DgoAgkIm8tr"
      },
      "outputs": [],
      "source": [
        "train_xs, train_ys = extract_features(train_dataframe)\n",
        "test_xs, test_ys = extract_features(test_dataframe)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qcguGFRcFgCQ"
      },
      "outputs": [],
      "source": [
        "# Let's define our label minimum and maximum.\n",
        "min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))\n",
        "# Our lattice models may have predictions above 1.0 due to numerical errors.\n",
        "# We can subtract this small epsilon value from our output_max to make sure we\n",
        "# do not predict values outside of our label bound.\n",
        "numerical_error_epsilon = 1e-5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oyOrtol7mW9r"
      },
      "source": [
        "设置用于在本指南中进行训练的默认值："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ns8pH2AnmgAC"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.01\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 500\n",
        "PREFITTING_NUM_EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ix2elMrGmiWX"
      },
      "source": [
        "## 特征配置\n",
        "\n",
        "使用 [tfl.configs.FeatureConfig](https://tensorflow.google.cn/lattice/api_docs/python/tfl/configs/FeatureConfig) 设置特征校准和按特征的配置。特征配置包括单调性约束、按特征的正则化（请参阅 [tfl.configs.RegularizerConfig](https://tensorflow.google.cn/lattice/api_docs/python/tfl/configs/RegularizerConfig)）以及点阵模型的点阵大小。\n",
        "\n",
        "请注意，我们必须为希望模型识别的任何特征完全指定特征配置。否则，模型将无法获知存在这样的特征。"
      ]
    },
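    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build intuition for what a monotonic piecewise-linear calibrator does, here is a minimal NumPy sketch. It is not TFL code: the function name `pwl_calibrate` and the output values are made up for illustration, and only the input keypoints are borrowed from the `chol` config defined later in this guide. The sketch interpolates linearly between keypoints and holds the end values constant for inputs outside the keypoint range.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Input keypoints (x positions) and hypothetical learned outputs (y values).\n",
        "input_keypoints = np.array([126.0, 210.0, 247.0, 286.0, 564.0])\n",
        "output_values = np.array([0.0, 0.2, 0.5, 0.7, 1.0])\n",
        "\n",
        "\n",
        "def pwl_calibrate(x):\n",
        "  # np.interp interpolates linearly between keypoints and holds the end\n",
        "  # values constant outside the keypoint range.\n",
        "  return np.interp(x, input_keypoints, output_values)\n",
        "\n",
        "\n",
        "# Non-decreasing outputs make the calibrator monotonically increasing.\n",
        "is_increasing = bool(np.all(np.diff(output_values) >= 0))\n",
        "print(pwl_calibrate(228.5), is_increasing)\n",
        "```\n",
        "\n",
        "In TFL the output values are trainable parameters, and a monotonicity constraint keeps them ordered during training."
      ]
    },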
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WLSfZ5G7-YT_"
      },
      "source": [
        "### 计算分位数\n",
        "\n",
        "尽管 `tfl.configs.FeatureConfig` 中 `pwl_calibration_input_keypoints` 的默认设置为“分位数”，但对于预制模型，我们必须手动定义输入关键点。为此，我们首先定义自己的辅助函数来计算分位数。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-LqqEp3k-06d"
      },
      "outputs": [],
      "source": [
        "def compute_quantiles(features,\n",
        "                      num_keypoints=10,\n",
        "                      clip_min=None,\n",
        "                      clip_max=None,\n",
        "                      missing_value=None):\n",
        "  # Clip min and max if desired.\n",
        "  if clip_min is not None:\n",
        "    features = np.maximum(features, clip_min)\n",
        "    features = np.append(features, clip_min)\n",
        "  if clip_max is not None:\n",
        "    features = np.minimum(features, clip_max)\n",
        "    features = np.append(features, clip_max)\n",
        "  # Make features unique.\n",
        "  unique_features = np.unique(features)\n",
        "  # Remove missing values if specified.\n",
        "  if missing_value is not None:\n",
        "    unique_features = np.delete(unique_features,\n",
        "                                np.where(unique_features == missing_value))\n",
        "  # Compute and return quantiles over unique non-missing feature values.\n",
        "  return np.quantile(\n",
        "      unique_features,\n",
        "      np.linspace(0., 1., num=num_keypoints),\n",
        "      interpolation='nearest').astype(float)"
      ]
    },
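    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick illustration of the idea behind `compute_quantiles`, the standalone sketch below picks evenly spaced entries from the sorted unique values of a toy array. The data and the names `toy_ages` and `toy_keypoints` are hypothetical, and the index-based selection is a simplified stand-in for `np.quantile` with nearest-value selection, not a reimplementation of it.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical ages with repeated values.\n",
        "toy_ages = np.array([29, 35, 35, 41, 50, 50, 58, 63, 77], dtype=float)\n",
        "unique_ages = np.unique(toy_ages)  # Sorted, duplicates removed.\n",
        "\n",
        "num_keypoints = 4\n",
        "# Evenly spaced positions in the sorted unique values, rounded to indices.\n",
        "positions = np.linspace(0, len(unique_ages) - 1, num=num_keypoints)\n",
        "toy_keypoints = unique_ages[np.round(positions).astype(int)]\n",
        "print(toy_keypoints)  # [29. 41. 58. 77.]\n",
        "```\n",
        "\n",
        "Working on unique values helps avoid repeated keypoints when a few values dominate the data."
      ]
    },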
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ePWXuDH7-1i1"
      },
      "source": [
        "### 定义我们的特征配置\n",
        "\n",
        "现在我们可以计算分位数了，我们为希望模型将其作为输入的每个特征定义一个特征配置。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8y27RmHIrSBn"
      },
      "outputs": [],
      "source": [
        "# Feature configs are used to specify how each feature is calibrated and used.\n",
        "feature_configs = [\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='age',\n",
        "        lattice_size=3,\n",
        "        monotonicity='increasing',\n",
        "        # We must set the keypoints manually.\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['age']],\n",
        "            num_keypoints=5,\n",
        "            clip_max=100),\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='sex',\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='cp',\n",
        "        monotonicity='increasing',\n",
        "        # Keypoints that are uniformly spaced.\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        pwl_calibration_input_keypoints=np.linspace(\n",
        "            np.min(train_xs[feature_name_indices['cp']]),\n",
        "            np.max(train_xs[feature_name_indices['cp']]),\n",
        "            num=4),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='chol',\n",
        "        monotonicity='increasing',\n",
        "        # Explicit input keypoints initialization.\n",
        "        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],\n",
        "        # Calibration can be forced to span the full output range by clamping.\n",
        "        pwl_calibration_clamp_min=True,\n",
        "        pwl_calibration_clamp_max=True,\n",
        "        # Per feature regularization.\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='fbs',\n",
        "        # Partial monotonicity: output(0) &lt;= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='trestbps',\n",
        "        monotonicity='decreasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['trestbps']], num_keypoints=5),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thalach',\n",
        "        monotonicity='decreasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['thalach']], num_keypoints=5),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='restecg',\n",
        "        # Partial monotonicity: output(0) &lt;= output(1), output(0) &lt;= output(2)\n",
        "        monotonicity=[(0, 1), (0, 2)],\n",
        "        num_buckets=3,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='exang',\n",
        "        # Partial monotonicity: output(0) &lt;= output(1)\n",
        "        monotonicity=[(0, 1)],\n",
        "        num_buckets=2,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='oldpeak',\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=5,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['oldpeak']], num_keypoints=5),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='slope',\n",
        "        # Partial monotonicity: output(0) &lt;= output(1), output(1) &lt;= output(2)\n",
        "        monotonicity=[(0, 1), (1, 2)],\n",
        "        num_buckets=3,\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='ca',\n",
        "        monotonicity='increasing',\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        pwl_calibration_input_keypoints=compute_quantiles(\n",
        "            train_xs[feature_name_indices['ca']], num_keypoints=4),\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name='thal',\n",
        "        # Partial monotonicity:\n",
        "        # output(normal) &lt;= output(fixed)\n",
        "        # output(normal) &lt;= output(reversible)\n",
        "        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],\n",
        "        num_buckets=3,\n",
        "        # We must specify the vocabulary list in order to later set the\n",
        "        # monotonicities since we used names and not indices.\n",
        "        vocabulary_list=thal_vocab_list,\n",
        "    ),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-XuAnP_-vyK6"
      },
      "source": [
        "接下来，我们需要确保为使用自定义词汇表的特征（例如上面的“thal”）正确设置单调性。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZIn2-EfGv--m"
      },
      "outputs": [],
      "source": [
        "tfl.premade_lib.set_categorical_monotonicities(feature_configs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mx50YgWMcxC4"
      },
      "source": [
        "## 校准线性模型\n",
        "\n",
        "要构造 TFL 预制模型，请首先从 [tfl.configs](https://tensorflow.google.cn/lattice/api_docs/python/tfl/configs) 构造模型配置。使用 [tfl.configs.CalibratedLinearConfig](https://tensorflow.google.cn/lattice/api_docs/python/tfl/configs/CalibratedLinearConfig) 构造校准线性模型。此模型会将分段线性和分类校准应用于输入特征，随后应用线性组合和可选的输出分段线性校准。使用输出校准或指定输出边界时，线性层会将加权平均应用于校准的输入。\n",
        "\n",
        "下面的示例将基于前 5 个特征创建一个校准线性模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UvMDJKqTc1vC"
      },
      "outputs": [],
      "source": [
        "# Model config defines the model structure for the premade model.\n",
        "linear_model_config = tfl.configs.CalibratedLinearConfig(\n",
        "    feature_configs=feature_configs[:5],\n",
        "    use_bias=True,\n",
        "    # We must set the output min and max to that of the label.\n",
        "    output_min=min_label,\n",
        "    output_max=max_label,\n",
        "    output_calibration=True,\n",
        "    output_calibration_num_keypoints=10,\n",
        "    output_initialization=np.linspace(min_label, max_label, num=10),\n",
        "    regularizer_configs=[\n",
        "        # Regularizer for the output calibrator.\n",
        "        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),\n",
        "    ])\n",
        "# A CalibratedLinear premade model constructed from the given model config.\n",
        "linear_model = tfl.premade.CalibratedLinear(linear_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')"
      ]
    },
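    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The weighted averaging mentioned above can be pictured with the following NumPy sketch. It is a simplification under stated assumptions, not the TFL implementation: the calibrated inputs and raw weights are made-up numbers, and the weights are simply taken to be non-negative and normalized to sum to 1.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical calibrated inputs for one example, each already in [0, 1].\n",
        "calibrated_inputs = np.array([0.2, 0.9, 0.4, 0.6, 0.1])\n",
        "# Hypothetical non-negative raw weights, normalized to sum to 1.\n",
        "raw_weights = np.array([1.0, 3.0, 2.0, 2.0, 2.0])\n",
        "weights = raw_weights / raw_weights.sum()\n",
        "\n",
        "# A weighted average cannot leave the range spanned by its inputs.\n",
        "output = float(np.dot(weights, calibrated_inputs))\n",
        "print(output)\n",
        "```\n",
        "\n",
        "Because the result stays within the range of the calibrated inputs, averaging pairs naturally with bounded outputs."
      ]
    },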
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3MC3-AyX00-A"
      },
      "source": [
        "现在，与任何其他 [tf.keras.Model](https://tensorflow.google.cn/api_docs/python/tf/keras/Model) 一样，我们编译该模型并将其拟合到我们的数据中。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hPlEK-yG1B-U"
      },
      "outputs": [],
      "source": [
        "linear_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "linear_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OG2ua0MGAkoi"
      },
      "source": [
        "训练完模型后，我们可以在测试集中对其进行评估。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HybGTvXxAoxV"
      },
      "outputs": [],
      "source": [
        "print('Test Set Evaluation...')\n",
        "print(linear_model.evaluate(test_xs, test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jAAJK-wlc15S"
      },
      "source": [
        "## 校准点阵模型\n",
        "\n",
        "使用 [tfl.configs.CalibratedLatticeConfig](https://tensorflow.google.cn/lattice/api_docs/python/tfl/configs/CalibratedLatticeConfig) 构造校准点阵模型。校准点阵模型会将分段线性和分类校准应用于输入特征，随后应用点阵模型和可选的输出分段线性校准。\n",
        "\n",
        "下面的示例将基于前 5 个特征创建一个校准点阵模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u7gNcrMtc4Lp"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice model: inputs are calibrated, then combined\n",
        "# non-linearly using a lattice layer.\n",
        "lattice_model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=feature_configs[:5],\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label],\n",
        "    regularizer_configs=[\n",
        "        # Torsion regularizer applied to the lattice to make it more linear.\n",
        "        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),\n",
        "        # Globally defined calibration regularizer is applied to all features.\n",
        "        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),\n",
        "    ])\n",
        "# A CalibratedLattice premade model constructed from the given model config.\n",
        "lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')"
      ]
    },
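    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A lattice layer is an interpolated look-up table. The sketch below shows the idea for a single 2x2 lattice over two calibrated inputs in [0, 1]. The vertex values and the helper name `lattice_2x2` are hypothetical and chosen only for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical learned values at the four corners of the unit square,\n",
        "# indexed as vertex_values[x_corner][y_corner].\n",
        "vertex_values = np.array([[0.0, 0.4],\n",
        "                          [0.6, 1.0]])\n",
        "\n",
        "\n",
        "def lattice_2x2(x, y):\n",
        "  # Bilinear interpolation: weight each corner by its closeness to (x, y).\n",
        "  wx = np.array([1.0 - x, x])\n",
        "  wy = np.array([1.0 - y, y])\n",
        "  return float(wx @ vertex_values @ wy)\n",
        "\n",
        "\n",
        "print(lattice_2x2(0.0, 0.0), lattice_2x2(1.0, 1.0), lattice_2x2(0.5, 0.5))\n",
        "```\n",
        "\n",
        "In this toy table the values increase along both axes, so the interpolated function is monotonically increasing in both inputs."
      ]
    },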
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nmc3TUIIGGoH"
      },
      "source": [
        "和以前一样，我们编译、拟合并评估我们的模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vIjOQGD2Gp_Z"
      },
      "outputs": [],
      "source": [
        "lattice_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "lattice_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(lattice_model.evaluate(test_xs, test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bx74CD4Cc4T3"
      },
      "source": [
        "## 校准点阵集成模型\n",
        "\n",
        "当特征数量很大时，可以使用集成模型，这种模型会为特征的子集创建多个较小的点阵并计算它们的输出平均值，而不是仅创建单个巨大的点阵。使用 [tfl.configs.CalibratedLatticeEnsembleConfig](https://tensorflow.google.cn/lattice/api_docs/python/tfl/configs/CalibratedLatticeEnsembleConfig) 构造集成点阵模型。校准点阵集成模型会将分段线性和分类校准应用于输入特征，随后应用点阵模型的集成和可选的输出分段线性校准。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mbg4lsKqnEkV"
      },
      "source": [
        "### 显式点阵集成初始化\n",
        "\n",
        "如果您已经知道要将哪些特征子集馈入点阵，则可以使用特征名称显式设置点阵。下面的示例将创建一个校准点阵集成模型，此模型具有 5 个点阵，每个点阵有 3 个特征。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yu8Twg8mdJ18"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],\n",
        "              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],\n",
        "              ['restecg', 'age', 'sex']],\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label])\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    explicit_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(\n",
        "    explicit_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PJYR0i6MMDyh"
      },
      "source": [
        "和以前一样，我们编译、拟合并评估我们的模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "capt98IOMHEm"
      },
      "outputs": [],
      "source": [
        "explicit_ensemble_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "explicit_ensemble_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(explicit_ensemble_model.evaluate(test_xs, test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VnI70C9gdKQw"
      },
      "source": [
        "### 随机点阵集成\n",
        "\n",
        "如果您不确定要将哪些特征子集馈入点阵，则另一个选择是为每个点阵使用随机的特征子集。下面的示例将创建一个校准点阵集成模型，此模型具有 5 个点阵，每个点阵有 3 个特征。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7EhWrQaPIXj8"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices='random',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label],\n",
        "    random_seed=42)\n",
        "# Now we must set the random lattice structure and construct the model.\n",
        "tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    random_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(\n",
        "    random_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
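    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the idea of random feature subsets concrete, the sketch below draws 5 subsets of 3 features with NumPy. It only illustrates the concept and is not the procedure used by `tfl.premade_lib.set_random_lattice_ensemble`. The variable names are hypothetical, and the subsets it prints will differ from the ones TFL picks.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "names = ['age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',\n",
        "         'exang', 'oldpeak', 'slope', 'ca', 'thal']\n",
        "num_lattices, lattice_rank = 5, 3\n",
        "\n",
        "rng = np.random.default_rng(42)\n",
        "# Each lattice gets `lattice_rank` distinct features, drawn independently.\n",
        "toy_lattices = [\n",
        "    sorted(rng.choice(names, size=lattice_rank, replace=False).tolist())\n",
        "    for _ in range(num_lattices)\n",
        "]\n",
        "for lattice in toy_lattices:\n",
        "  print(lattice)\n",
        "```\n",
        "\n",
        "With independent draws like these, some features may appear in several lattices while others appear in none."
      ]
    },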
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sbxcIF0PJUDc"
      },
      "source": [
        "和以前一样，我们编译、拟合并评估我们的模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w0YdCDyGJY1G"
      },
      "outputs": [],
      "source": [
        "random_ensemble_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "random_ensemble_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(random_ensemble_model.evaluate(test_xs, test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZhJWe7fZIs4-"
      },
      "source": [
        "### RTL 层随机点阵集成\n",
        "\n",
        "使用随机点阵集成时，可以指定模型使用单个 `tfl.layers.RTL` 层。我们注意到，`tfl.layers.RTL` 仅支持单调性约束，对于所有特征都必须具有相同的点阵大小，并且不包含按特征的正则化。请注意，与使用单独的 `tfl.layers.Lattice` 实例相比，使用 `tfl.layers.RTL` 层可使您扩展到更大的集成。\n",
        "\n",
        "下面的示例将创建一个校准点阵集成模型，此模型具有 5 个点阵，每个点阵有 3 个特征。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0PC9oRFYJMF_"
      },
      "outputs": [],
      "source": [
        "# Make sure our feature configs have the same lattice size, no per-feature\n",
        "# regularization, and only monotonicity constraints.\n",
        "rtl_layer_feature_configs = copy.deepcopy(feature_configs)\n",
        "for feature_config in rtl_layer_feature_configs:\n",
        "  feature_config.lattice_size = 2\n",
        "  feature_config.unimodality = 'none'\n",
        "  feature_config.reflects_trust_in = None\n",
        "  feature_config.dominates = None\n",
        "  feature_config.regularizer_configs = None\n",
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combined non-linearly and averaged using multiple lattice layers.\n",
        "rtl_layer_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=rtl_layer_feature_configs,\n",
        "    lattices='rtl_layer',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label],\n",
        "    random_seed=42)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config. Note that we do not have to specify the lattices by calling\n",
        "# a helper function (like before with random) because the RTL Layer will take\n",
        "# care of that for us.\n",
        "rtl_layer_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    rtl_layer_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(\n",
        "    rtl_layer_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yWdxZpS0JWag"
      },
      "source": [
        "和以前一样，我们编译、拟合并评估我们的模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HQdkkWwqJW8p"
      },
      "outputs": [],
      "source": [
        "rtl_layer_ensemble_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "rtl_layer_ensemble_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(rtl_layer_ensemble_model.evaluate(test_xs, test_ys))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A61VpAl8uOiT"
      },
      "source": [
        "### 晶体点阵集成\n",
        "\n",
        "预制模型还提供了一种称为[晶体](https://papers.nips.cc/paper/6377-fast-and-flexible-monotonic-functions-with-ensembles-of-lattices)的启发式特征排列算法。要使用晶体算法，首先训练一个预拟合模型，此模型会估算成对的特征交互。然后，它将对最终集成进行排列，使具有更多非线性交互的特征处于同一点阵中。\n",
        "\n",
        "预制库提供了一些辅助函数，用于构造预拟合模型配置以及提取晶体结构。请注意，预拟合模型不需要完全训练，因此几个周期便已足够。\n",
        "\n",
        "下面的示例将创建一个校准点阵集成模型，此模型具有 5 个点阵，每个点阵有 3 个特征。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yT5eiknCu9sj"
      },
      "outputs": [],
      "source": [
        "# This is a calibrated lattice ensemble model: inputs are calibrated, then\n",
        "# combines non-linearly and averaged using multiple lattice layers.\n",
        "crystals_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(\n",
        "    feature_configs=feature_configs,\n",
        "    lattices='crystals',\n",
        "    num_lattices=5,\n",
        "    lattice_rank=3,\n",
        "    output_min=min_label,\n",
        "    output_max=max_label - numerical_error_epsilon,\n",
        "    output_initialization=[min_label, max_label],\n",
        "    random_seed=42)\n",
        "# Now that we have our model config, we can construct a prefitting model config.\n",
        "prefitting_model_config = tfl.premade_lib.construct_prefitting_model_config(\n",
        "    crystals_ensemble_model_config)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# prefitting model config.\n",
        "prefitting_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    prefitting_model_config)\n",
        "# We can compile and train our prefitting model as we like.\n",
        "prefitting_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "prefitting_model.fit(\n",
        "    train_xs,\n",
        "    train_ys,\n",
        "    epochs=PREFITTING_NUM_EPOCHS,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    verbose=False)\n",
        "# Now that we have our trained prefitting model, we can extract the crystals.\n",
        "tfl.premade_lib.set_crystals_lattice_ensemble(crystals_ensemble_model_config,\n",
        "                                              prefitting_model_config,\n",
        "                                              prefitting_model)\n",
        "# A CalibratedLatticeEnsemble premade model constructed from the given\n",
        "# model config.\n",
        "crystals_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(\n",
        "    crystals_ensemble_model_config)\n",
        "# Let's plot our model.\n",
        "tf.keras.utils.plot_model(\n",
        "    crystals_ensemble_model, show_layer_names=False, rankdir='LR')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PRLU1z-216h8"
      },
      "source": [
        "和以前一样，我们编译、拟合并评估我们的模型。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U73On3v91-Qq"
      },
      "outputs": [],
      "source": [
        "crystals_ensemble_model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=[tf.keras.metrics.AUC()],\n",
        "    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))\n",
        "crystals_ensemble_model.fit(\n",
        "    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)\n",
        "print('Test Set Evaluation...')\n",
        "print(crystals_ensemble_model.evaluate(test_xs, test_ys))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "premade_models.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
